| | x | y |
|---|---|---|
| 0 | 0.000000 | 3.113179 |
| 1 | 0.010101 | 3.774512 |
| 2 | 0.020202 | 4.045562 |
| 3 | 0.030303 | 3.207971 |
| 4 | 0.040404 | 3.336638 |
| ... | ... | ... |
| 95 | 0.959596 | 1.951793 |
| 96 | 0.969697 | 0.224769 |
| 97 | 0.979798 | -0.387220 |
| 98 | 0.989899 | 1.304032 |
| 99 | 1.000000 | 0.174600 |
100 rows × 2 columns
Lecture 25
# design matrix (n x 1) and response taken from the data frame d
X = d.x.to_numpy().reshape(-1,1)
y = d.y.to_numpy()
with pm.Model() as model:
    l = pm.Gamma("l", alpha=2, beta=1)      # length scale
    s = pm.HalfCauchy("s", beta=5)          # scale of the GP
    nug = pm.HalfCauchy("nug", beta=5)      # nugget (observation noise sd)
    cov = s**2 * pm.gp.cov.ExpQuad(input_dim=1, ls=l)
    gp = pm.gp.Marginal(cov_func=cov)
    y_ = gp.marginal_likelihood(
        "y", X=X, y=y, sigma=nug
    )
Beyond supporting different sampling steps, PyMC can also run your model using different implementations of the NUTS sampler.
These can be selected via the nuts_sampler argument of pm.sample(), which currently supports:
pymc - the standard sampler, which uses PyMC's C backend
blackjax - uses the blackjax library, a collection of samplers written for JAX
numpyro - uses NumPyro, a probabilistic programming library from the Pyro project built on JAX
nutpie - provides a wrapper to the nuts-rs Rust library (a slight variation on NUTS compared to PyMC & Stan)
6.26 s ± 83.3 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)
%%timeit -r 3
with model:
    post_jax = pm.sample(nuts_sampler="blackjax", chains=4, progressbar=False)
4.16 s ± 120 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)
At the moment, both Python and R offer two interfaces to Stan:
PyStan & RStan - native language interfaces to the underlying Stan C++ libraries
CmdStanPy & CmdStanR - wrappers around the CmdStan command line interface
Any of the above tools requires a modern C++ toolchain (C++17 support required).
Stan code is divided into specific blocks depending on usage. All of the following blocks are optional, but any that appear must follow the order given below.
functions {
  // user-defined functions
}
data {
  // declares the required data for the model
}
transformed data {
  // allows the definition of constants and transforms of the data
}
parameters {
  // declares the model’s parameters
}
transformed parameters {
  // allows variables to be defined in terms of data and parameters
}
model {
  // defines the log probability function
}
generated quantities {
  // allows derived quantities based on parameters, data, and random number generation
}
CmdStanMCMC: model=bernoulli chains=4['method=sample', 'algorithm=hmc', 'adapt', 'engaged=1']
csv_files:
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpv2zs_85x/bernoulli6pe85w52/bernoulli-20250415121940_1.csv
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpv2zs_85x/bernoulli6pe85w52/bernoulli-20250415121940_2.csv
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpv2zs_85x/bernoulli6pe85w52/bernoulli-20250415121940_3.csv
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpv2zs_85x/bernoulli6pe85w52/bernoulli-20250415121940_4.csv
output_files:
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpv2zs_85x/bernoulli6pe85w52/bernoulli-20250415121940_0-stdout.txt
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpv2zs_85x/bernoulli6pe85w52/bernoulli-20250415121940_1-stdout.txt
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpv2zs_85x/bernoulli6pe85w52/bernoulli-20250415121940_2-stdout.txt
/var/folders/v7/wrxd7cdj6l5gzr0191__m9lr0000gr/T/tmpv2zs_85x/bernoulli6pe85w52/bernoulli-20250415121940_3-stdout.txt
{'theta': array([0.45722, 0.36427, 0.36427, 0.52572, 0.31933, 0.2019 , 0.38248, 0.16555, 0.29055, 0.33254, 0.28832, 0.32835, 0.19974, 0.37581, 0.37581, 0.3844 , 0.19428, 0.36735, 0.38099, 0.16314, 0.11972,
0.18252, 0.13986, 0.15708, 0.12406, 0.31417, 0.37685, 0.1857 , 0.27825, 0.23875, ..., 0.43157, 0.33939, 0.18887, 0.06114, 0.36835, 0.37939, 0.19438, 0.26716, 0.4212 , 0.30349, 0.25675,
0.20348, 0.17883, 0.17883, 0.12666, 0.28372, 0.09175, 0.1248 , 0.06877, 0.14584, 0.10053, 0.11008, 0.31683, 0.31493, 0.2744 , 0.14932, 0.16615, 0.09647, 0.09929, 0.07015], shape=(4000,))}
Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.
Checking sampler transitions for divergences.
No divergent transitions found.
Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.
Rank-normalized split effective sample size satisfactory for all parameters.
Rank-normalized split R-hat values satisfactory for all parameters.
Processing complete, no problems detected.
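The R-hat check reported above compares between-chain and within-chain variability after splitting each chain in half. The following is a minimal sketch of plain split R-hat using only the standard library; it omits the rank normalization that CmdStan applies, and the function name `split_rhat` and the simulated chains are illustrative assumptions, not part of CmdStanPy.

```python
import random
import statistics

def split_rhat(chains):
    """Plain (not rank-normalized) split R-hat for equal-length chains."""
    half = len(chains[0]) // 2
    # split every chain into a first and a second half
    halves = []
    for chain in chains:
        halves.append(chain[:half])
        halves.append(chain[half:2 * half])
    n = half
    means = [statistics.fmean(h) for h in halves]
    # W: average within-chain variance, B: between-chain variance scaled by n
    within = statistics.fmean([statistics.variance(h) for h in halves])
    between = n * statistics.variance(means)
    var_plus = (n - 1) / n * within + between / n
    return (var_plus / within) ** 0.5

# Simulated draws (made up for illustration, not from the fits in these slides)
rng = random.Random(0)
mixed = [[rng.gauss(0, 1) for _ in range(1000)] for _ in range(4)]
stuck = [[rng.gauss(mu, 1) for _ in range(1000)] for mu in (0, 0, 3, 3)]

print("well mixed:", split_rhat(mixed))   # close to 1
print("not mixed: ", split_rhat(stuck))   # well above 1
```

Values close to 1 indicate that the chains agree with one another; chains exploring different regions push the between-chain term up and R-hat well above 1.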
Lec25/gp.stan
data {
  int<lower=1> N;
  array[N] real x;
  vector[N] y;
}
transformed data {
  array[N] real xn = to_array_1d(x);
  vector[N] zeros = rep_vector(0, N);
}
parameters {
  real<lower=0> l;
  real<lower=0> s;
  real<lower=0> nug;
}
model {
  // Covariance
  matrix[N, N] K = gp_exp_quad_cov(x, s, l);
  matrix[N, N] L = cholesky_decompose(add_diag(K, nug^2));
  // priors
  l ~ gamma(2, 1);
  s ~ cauchy(0, 5);
  nug ~ cauchy(0, 1);
  // model
  y ~ multi_normal_cholesky(rep_vector(0, N), L);
}
12:19:41 - cmdstanpy - INFO - CmdStan start processing
12:19:41 - cmdstanpy - INFO - Chain [1] start processing
12:19:41 - cmdstanpy - INFO - Chain [2] start processing
12:19:41 - cmdstanpy - INFO - Chain [3] start processing
12:19:41 - cmdstanpy - INFO - Chain [4] start processing
12:19:43 - cmdstanpy - INFO - Chain [2] done processing
12:19:43 - cmdstanpy - INFO - Chain [1] done processing
12:19:43 - cmdstanpy - INFO - Chain [3] done processing
12:19:43 - cmdstanpy - INFO - Chain [4] done processing
12:19:43 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp.stan', line 18, column 2 to column 58)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp.stan', line 18, column 2 to column 58)
Exception: gp_exp_quad_cov: length_scale is 0, but must be positive! (in 'gp.stan', line 17, column 2 to column 44)
Exception: gp_exp_quad_cov: length_scale is 0, but must be positive! (in 'gp.stan', line 17, column 2 to column 44)
Exception: gp_exp_quad_cov: length_scale is 0, but must be positive! (in 'gp.stan', line 17, column 2 to column 44)
Consider re-running with show_console=True if the above output is unclear!
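To make the model block of gp.stan concrete, here is a rough pure-Python sketch of the quantity it evaluates: the zero-mean multivariate normal log density of y under the covariance s² · ExpQuad(l) + nug² · I, computed through a Cholesky factor. The helper names (`exp_quad_cov`, `cholesky`, `gp_log_marginal`) and the toy data are assumptions made for illustration; Stan's own implementation differs. The sketch also shows where a "not positive definite" exception comes from: the factorization fails when a pivot is not strictly positive. Stan treats such exceptions as a rejected proposal, which is why the log above reports them as non-fatal.

```python
import math

def exp_quad_cov(x, s, l):
    """Squared-exponential covariance: K[i][j] = s^2 * exp(-(x_i - x_j)^2 / (2 l^2))."""
    return [[s**2 * math.exp(-((xi - xj) ** 2) / (2 * l**2)) for xj in x] for xi in x]

def cholesky(A):
    """Lower-triangular L with L L^T = A; raises ValueError if a pivot is not positive."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            acc = A[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                if acc <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][j] = math.sqrt(acc)
            else:
                L[i][j] = acc / L[j][j]
    return L

def gp_log_marginal(x, y, s, l, nug):
    """log N(y | 0, K + nug^2 I) with K the squared-exponential covariance."""
    n = len(x)
    K = exp_quad_cov(x, s, l)
    for i in range(n):
        K[i][i] += nug**2          # same role as add_diag(K, nug^2)
    L = cholesky(K)
    # forward substitution: solve L z = y, so that z'z = y' (K + nug^2 I)^{-1} y
    z = []
    for i in range(n):
        z.append((y[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i])
    log_det = 2.0 * sum(math.log(L[i][i]) for i in range(n))
    return -0.5 * (sum(v * v for v in z) + log_det + n * math.log(2 * math.pi))

# Toy inputs (made up for illustration, not the data used in these slides)
x = [i / 9 for i in range(10)]
y = [math.sin(6 * xi) for xi in x]
print(gp_log_marginal(x, y, s=2.0, l=0.1, nug=0.7))
```

The nugget term on the diagonal is what keeps the matrix comfortably positive definite; with nug = 0 and repeated x values the factorization in this sketch fails.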
| | Mean | MCSE | StdDev | MAD | 5% | 50% | 95% | ESS_bulk | ESS_tail | R_hat |
|---|---|---|---|---|---|---|---|---|---|---|
| lp__ | -42.959300 | 0.030618 | 1.228470 | 1.044420 | -45.378100 | -42.665800 | -41.588600 | 1620.19 | 2108.61 | 1.00338 |
| l | 0.106533 | 0.000580 | 0.024411 | 0.021621 | 0.072973 | 0.102830 | 0.153564 | 2002.96 | 1784.41 | 1.00357 |
| s | 2.199350 | 0.022507 | 0.823502 | 0.594738 | 1.321590 | 2.007950 | 3.681860 | 1845.11 | 1422.63 | 1.00192 |
| nug | 0.732305 | 0.001148 | 0.057479 | 0.057420 | 0.645550 | 0.728543 | 0.832352 | 2597.64 | 2582.88 | 1.00106 |
Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.
Checking sampler transitions for divergences.
No divergent transitions found.
Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.
Rank-normalized split effective sample size satisfactory for all parameters.
Rank-normalized split R-hat values satisfactory for all parameters.
Processing complete, no problems detected.
The nutpie package can also be used to compile and run Stan models; it uses a package called bridgestan to interface with Stan.
import nutpie
m = nutpie.compile_stan_model(filename="Lec25/gp.stan")
m = m.with_data(x=d["x"],y=d["y"],N=len(d["x"]))
gp_fit_nutpie = nutpie.sample(m, chains=4)
Sampler Progress
Total Chains: 4
Active Chains: 0
Finished Chains: 4
Sampling for now
Estimated Time to Completion: now
| Progress | Draws | Divergences | Step Size | Gradients/Draw |
|---|---|---|---|---|
| | 1400 | 0 | 0.82 | 3 |
| | 1400 | 0 | 0.80 | 3 |
| | 1400 | 0 | 0.80 | 1 |
| | 1400 | 0 | 0.79 | 3 |
Lec25/gp2.stan
functions {
  // From https://mc-stan.org/docs/stan-users-guide/gaussian-processes.html#predictive-inference-with-a-gaussian-process
  vector gp_pred_rng(array[] real x2,
                     vector y1,
                     array[] real x1,
                     real alpha,
                     real rho,
                     real sigma,
                     real delta) {
    int N1 = rows(y1);
    int N2 = size(x2);
    vector[N2] f2;
    {
      matrix[N1, N1] L_K;
      vector[N1] K_div_y1;
      matrix[N1, N2] k_x1_x2;
      matrix[N1, N2] v_pred;
      vector[N2] f2_mu;
      matrix[N2, N2] cov_f2;
      matrix[N2, N2] diag_delta;
      matrix[N1, N1] K;
      K = gp_exp_quad_cov(x1, alpha, rho);
      for (n in 1:N1) {
        K[n, n] = K[n, n] + square(sigma);
      }
      L_K = cholesky_decompose(K);
      K_div_y1 = mdivide_left_tri_low(L_K, y1);
      K_div_y1 = mdivide_right_tri_low(K_div_y1', L_K)';
      k_x1_x2 = gp_exp_quad_cov(x1, x2, alpha, rho);
      f2_mu = (k_x1_x2' * K_div_y1);
      v_pred = mdivide_left_tri_low(L_K, k_x1_x2);
      cov_f2 = gp_exp_quad_cov(x2, alpha, rho) - v_pred' * v_pred;
      diag_delta = diag_matrix(rep_vector(delta, N2));
      f2 = multi_normal_rng(f2_mu, cov_f2 + diag_delta);
    }
    return f2;
  }
}
data {
  int<lower=1> N;      // number of observations
  array[N] real x;     // univariate covariate
  vector[N] y;         // target variable
  int<lower=1> Np;     // number of test points
  array[Np] real xp;   // univariate test points
}
transformed data {
  real delta = 1e-9;
}
parameters {
  real<lower=0> l;
  real<lower=0> s;
  real<lower=0> nug;
}
model {
  // Covariance
  matrix[N, N] K = gp_exp_quad_cov(x, s, l);
  matrix[N, N] L = cholesky_decompose(add_diag(K, nug^2));
  // priors
  l ~ gamma(2, 1);
  s ~ cauchy(0, 5);
  nug ~ cauchy(0, 1);
  // model
  y ~ multi_normal_cholesky(rep_vector(0, N), L);
}
generated quantities {
  // function scaled back to the original scale
  vector[Np] f = gp_pred_rng(xp, y, x, s, l, nug, delta);
}
12:20:09 - cmdstanpy - INFO - CmdStan start processing
12:20:09 - cmdstanpy - INFO - Chain [1] start processing
12:20:09 - cmdstanpy - INFO - Chain [2] start processing
12:20:09 - cmdstanpy - INFO - Chain [3] start processing
12:20:09 - cmdstanpy - INFO - Chain [4] start processing
12:20:12 - cmdstanpy - INFO - Chain [3] done processing
12:20:12 - cmdstanpy - INFO - Chain [2] done processing
12:20:12 - cmdstanpy - INFO - Chain [1] done processing
12:20:12 - cmdstanpy - INFO - Chain [4] done processing
12:20:12 - cmdstanpy - WARNING - Non-fatal error during sampling:
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 58, column 2 to column 58)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 58, column 2 to column 58)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 58, column 2 to column 58)
Exception: gp_exp_quad_cov: length_scale is 0, but must be positive! (in 'gp2.stan', line 57, column 2 to column 44)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 58, column 2 to column 58)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 58, column 2 to column 58)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 58, column 2 to column 58)
Exception: cholesky_decompose: Matrix m is not positive definite (in 'gp2.stan', line 58, column 2 to column 58)
Consider re-running with show_console=True if the above output is unclear!
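For reference, the gp_pred_rng function in gp2.stan can be read as the usual Gaussian process conditional distribution. The block below is a sketch of that reading, written in the function's own argument names; in the generated quantities call these are bound as alpha = s, rho = l, and sigma = nug. The kernel notation k(·, ·) is introduced here for illustration only.

```latex
\begin{aligned}
k(x, x') &= \alpha^2 \exp\!\left(-\frac{(x - x')^2}{2\rho^2}\right) \\
K &= k(x_1, x_1) + \sigma^2 I = L_K L_K^\top \\
\mu_2 &= k(x_1, x_2)^\top K^{-1} y_1 \\
V &= L_K^{-1}\, k(x_1, x_2) \\
\Sigma_2 &= k(x_2, x_2) - V^\top V
          = k(x_2, x_2) - k(x_1, x_2)^\top K^{-1} k(x_1, x_2) \\
f_2 &\sim \mathcal{N}\!\left(\mu_2,\; \Sigma_2 + \delta I\right)
\end{aligned}
```

Here the two triangular solves that build K_div_y1 compute K⁻¹y₁, v_pred corresponds to V, and the small delta added to the diagonal keeps the predictive covariance numerically positive definite.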
| | Mean | MCSE | StdDev | MAD | 5% | 50% | 95% | ESS_bulk | ESS_tail | R_hat |
|---|---|---|---|---|---|---|---|---|---|---|
| lp__ | -43.042800 | 0.032191 | 1.273380 | 1.072960 | -45.623200 | -42.734000 | -41.602800 | 1615.63 | 2309.47 | 1.001810 |
| l | 0.106153 | 0.000558 | 0.024983 | 0.022396 | 0.070874 | 0.103053 | 0.153686 | 2101.05 | 2029.25 | 1.003470 |
| s | 2.192220 | 0.019064 | 0.819682 | 0.589771 | 1.294680 | 1.997890 | 3.735240 | 2269.50 | 1881.09 | 1.001400 |
| nug | 0.730838 | 0.001140 | 0.057326 | 0.058693 | 0.643901 | 0.727256 | 0.827427 | 2575.18 | 2413.21 | 1.001580 |
| f[1] | 3.469980 | 0.007242 | 0.440289 | 0.441466 | 2.722780 | 3.472880 | 4.183230 | 3726.82 | 3836.11 | 0.999831 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| f[117] | -0.607229 | 0.034081 | 2.037360 | 1.833990 | -4.050970 | -0.529064 | 2.534820 | 3628.10 | 3736.43 | 0.999619 |
| f[118] | -0.564853 | 0.034749 | 2.086650 | 1.873460 | -4.042150 | -0.498740 | 2.639500 | 3649.49 | 3811.01 | 0.999575 |
| f[119] | -0.520251 | 0.035192 | 2.128040 | 1.901170 | -4.098390 | -0.455036 | 2.803440 | 3693.62 | 3637.82 | 0.999656 |
| f[120] | -0.474758 | 0.035421 | 2.162410 | 1.935030 | -4.140110 | -0.403482 | 2.907520 | 3764.51 | 3579.83 | 0.999793 |
| f[121] | -0.429516 | 0.035444 | 2.190850 | 1.916190 | -4.094190 | -0.381014 | 3.005500 | 3846.13 | 3581.04 | 0.999941 |
125 rows × 10 columns
array([ 3.46998, 3.57005, 3.61907, 3.61183, 3.54474, 3.41625, 3.22692, 2.9796 , 2.67934, 2.33323, 1.94999, 1.53955, 1.11248, 0.67938, 0.25037, -0.16535, -0.55993, -0.92697, -1.26164,
-1.56055, -1.82155, -2.04347, -2.22579, -2.36835, -2.47109, -2.53395, -2.55682, -2.53968, -2.4828 , -2.38702, -2.25396, -2.08616, -1.88706, -1.6609 , -1.41247, -1.14688, -0.86939, -0.58521,
-0.29953, -0.01756, 0.25538, 0.5138 , 0.75207, 0.96449, 1.14549, 1.28993, 1.39343, 1.45278, 1.46633, 1.43433, 1.3592 , 1.24557, 1.10014, 0.93141, 0.74912, 0.56368, 0.3855 ,
0.2243 , 0.0885 , -0.01538, -0.08321, -0.11358, -0.10792, -0.07035, -0.00732, 0.07309, 0.16196, 0.25052, 0.33116, 0.39821, 0.44857, 0.48193, 0.50069, 0.50941, 0.51416, 0.52155,
0.5379 , 0.56842, 0.61659, 0.68387, 0.76956, 0.87107, 0.9842 , 1.10373, 1.22392, 1.33897, 1.44339, 1.53228, 1.60137, 1.64709, 1.66656, 1.65762, 1.61893, 1.55011, 1.45194,
1.32642, 1.17686, 1.00774, 0.82447, 0.63302, 0.43955, 0.24995, 0.06956, -0.09716, -0.24674, -0.37674, -0.48575, -0.57327, -0.63959, -0.68571, -0.71317, -0.72393, -0.72023, -0.70444,
-0.67891, -0.64584, -0.60723, -0.56485, -0.52025, -0.47476, -0.42952])
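The 121 values in the array above agree with the Mean column of the summary table (3.46998 for f[1], -0.42952 for f[121]), so they appear to be the posterior means of f at each test point. A minimal sketch of that kind of reduction follows, assuming the draws are arranged with one row per draw and one column per test point; the helper `posterior_summary` and the tiny draws matrix are hypothetical and are not taken from the fit.

```python
import statistics

def posterior_summary(draws):
    """Column-wise (mean, ~5% quantile, ~95% quantile) for draws[i][j] = draw i of element j."""
    summary = []
    for column in zip(*draws):
        # 19 cut points at 5%, 10%, ..., 95%; keep the first and last
        cuts = statistics.quantiles(column, n=20, method="inclusive")
        summary.append((statistics.fmean(column), cuts[0], cuts[-1]))
    return summary

# Made-up draws: 4 draws of a 2-element quantity
draws = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
for mean, lo, hi in posterior_summary(draws):
    print(mean, lo, hi)
```

With real output, the same reduction would be applied over the draws axis of the array of f draws, giving one mean and one interval per test point.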
Sta 663 - Spring 2025